{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "CyCLIP-Example.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q7IM3Yl5b4Qr"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/goel-shashank/CyCLIP.git\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install ftfy"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd CyCLIP/"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "POuRQtk0coPt",
        "outputId": "1f411fbb-e3b6-4e7d-ca03-53b7ff261e78"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/CyCLIP\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import requests\n",
        "from io import BytesIO\n",
        "from PIL import Image, ImageFile\n",
        "\n",
        "from pkgs.openai.clip import load as load_model\n",
        "\n",
        "ImageFile.LOAD_TRUNCATED_IMAGES = True"
      ],
      "metadata": {
        "id": "4HALKOrRcs1g"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prefer the GPU, but fall back to the CPU so the notebook still runs on a runtime without CUDA.\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
      ],
      "metadata": {
        "id": "tOS0UG2dczTf"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_inputs(image, caption):\n",
        "    \"\"\"Turn an image and a caption into model inputs placed on `device`.\n",
        "\n",
        "    Uses the notebook-level `processor` and `device`, so both must be defined before the first call.\n",
        "    Returns a tuple (input_ids, attention_mask, pixel_values); the image is converted to RGB first and\n",
        "    the processed image gets an extra leading dimension via unsqueeze(0).\n",
        "    \"\"\"\n",
        "    text = processor.process_text(caption)\n",
        "    pixels = processor.process_image(image.convert(\"RGB\"))\n",
        "\n",
        "    input_ids = text['input_ids'].to(device)\n",
        "    attention_mask = text['attention_mask'].to(device)\n",
        "    pixel_values = pixels.to(device).unsqueeze(0)\n",
        "    return input_ids, attention_mask, pixel_values"
      ],
      "metadata": {
        "id": "jJcM2Xcmc8Xq"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "## pretrained = True loads the original OpenAI CLIP model trained on 400M image-text pairs\n",
        "model, processor = load_model(name = 'RN50', pretrained = False)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mTpvpf6yc9Ji",
        "outputId": "cb4f9a28-59e4-4e7b-ae5e-c6b90c21852c"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|████████████████████████████████████████| 256M/256M [00:02<00:00, 116MiB/s]\n",
            "/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py:333: UserWarning: Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.\n",
            "  \"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. \"\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model.to(device)"
      ],
      "metadata": {
        "id": "jbaVse-EdGXr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "## Replace with the location of the checkpoint \n",
        "## The link for checkpoints -- https://drive.google.com/drive/u/0/folders/1K0kPJZ3MA4KAdx3Fpq25dgW59wIf7M-x\n",
        "\n",
        "checkpoint = '/content/drive/MyDrive/Spring22/checkpoints/cyclip.pt/best.pt'"
      ],
      "metadata": {
        "id": "Z0hKVUStdeLs"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): torch.load unpickles the file, so only open checkpoints from a source you trust.\n",
        "state_dict = torch.load(checkpoint, map_location = device)[\"state_dict\"]\n",
        "\n",
        "# Some checkpoints carry a \"module.\" prefix on their keys (presumably left by a parallel wrapper at\n",
        "# training time -- confirm). Strip it per key, and only where it is really present, so that keys\n",
        "# without the prefix are never truncated.\n",
        "prefix = \"module.\"\n",
        "state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}\n",
        "\n",
        "model.load_state_dict(state_dict)\n",
        "model.eval()"
      ],
      "metadata": {
        "id": "_XsJ4FEddr0l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "url = 'https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2'"
      ],
      "metadata": {
        "id": "fTE3ssYpdtTn"
      },
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fail fast on a hung connection or an HTTP error instead of handing an error page to PIL.\n",
        "response = requests.get(url, timeout = 30)\n",
        "response.raise_for_status()\n",
        "img = Image.open(BytesIO(response.content))"
      ],
      "metadata": {
        "id": "Hk1B5Hbyfseo"
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "caption1 = 'a photo of dogs'\n",
        "caption2 = 'a photo of cats'"
      ],
      "metadata": {
        "id": "gHJ_r6E2f6Vo"
      },
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def clipscore(model, output):\n",
        "    \"\"\"Return the scaled image-text similarity of a model output as a Python number.\n",
        "\n",
        "    The image embeddings are multiplied by exp(model.logit_scale) and then matrix-multiplied with the\n",
        "    transposed text embeddings. `.item()` only succeeds when that product holds exactly one element.\n",
        "    \"\"\"\n",
        "    scale = model.logit_scale.exp()\n",
        "    similarity = (scale * output.image_embeds) @ output.text_embeds.t()\n",
        "    return similarity.item()"
      ],
      "metadata": {
        "id": "6yQdAJmHf7YY"
      },
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "input1 = get_inputs(img, caption1)\n",
        "input2 = get_inputs(img, caption2)\n",
        "\n",
        "# Inference only: disable autograd so no computation graph is built or kept alive for the outputs.\n",
        "with torch.no_grad():\n",
        "    output1 = model(input_ids = input1[0], attention_mask = input1[1], pixel_values = input1[2])\n",
        "    output2 = model(input_ids = input2[0], attention_mask = input2[1], pixel_values = input2[2])\n",
        "\n",
        "    print(clipscore(model, output1))\n",
        "    print(clipscore(model, output2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ctuT4MBWgZMP",
        "outputId": "bfd50611-2a34-4c7d-a7b4-32f76da6b9ae"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "41.97170639038086\n",
            "34.910865783691406\n"
          ]
        }
      ]
    }
  ]
}